from IPython import display
import time as time
import datetime
import matplotlib.pyplot as plt
import numpy as np
import os
from IPython.core.magic import register_cell_magic
import h5py


def getfiles_txt(dirpath, withstring=None):
    """Return the names of the ``.txt`` files in ``dirpath``, oldest first.

    Only regular files directly inside ``dirpath`` are considered (no
    recursion; sub-directories are skipped). The result is sorted by
    modification time, ascending.

    Parameters
    ----------
    dirpath : str
        Directory to list.
    withstring : str, optional
        If given, keep only file names containing this substring.

    Returns
    -------
    list of str
        Bare file names (not full paths).
    """
    # Filter first so we only stat the files we actually return; the sort is
    # stable, so the resulting order is the same as sorting everything first.
    names = [
        name for name in os.listdir(dirpath)
        if os.path.isfile(os.path.join(dirpath, name))
        and name.endswith('.txt')
        and (withstring is None or withstring in name)
    ]
    names.sort(key=lambda name: os.path.getmtime(os.path.join(dirpath, name)))
    return names

def getfiles(dirpath, extension, withstring=None):
    """
    Get all the files in 'dirpath' with the file extension 'extension',
    sorted by modification time (oldest first).

    Only regular files directly inside 'dirpath' are returned (no recursion;
    sub-directories are skipped). Paths are built with os.path.join, so
    'dirpath' does not need a trailing separator.

    Parameters
    ----------
    dirpath : str
        Directory to list.
    extension : str or tuple of str
        Required file-name ending, e.g. '.h5'. A tuple accepts any of
        several endings (passed straight to str.endswith).
    withstring : str, optional
        If given, keep only file names containing this substring.

    Returns
    -------
    list of str
        Bare file names (not full paths).
    """
    # Filter before sorting so only the returned files are stat'ed; the sort
    # is stable, so the order matches sorting the whole listing first.
    names = [
        name for name in os.listdir(dirpath)
        if os.path.isfile(os.path.join(dirpath, name))
        and name.endswith(extension)
        and (withstring is None or withstring in name)
    ]
    names.sort(key=lambda name: os.path.getmtime(os.path.join(dirpath, name)))
    return names

# Retrieve all files with the .h5 extension

# obsolete
def getfiles_h5(dirpath, withstring=None):
    """Return the names of the ``.h5`` files in ``dirpath``, oldest first.

    Only regular files directly inside ``dirpath`` are considered; the
    result is sorted by modification time, ascending. ``.hdf5`` files are
    NOT matched (see getfiles_hdf5).

    Parameters
    ----------
    dirpath : str
        Directory to list.
    withstring : str, optional
        If given, keep only file names containing this substring.

    Returns
    -------
    list of str
        Bare file names (not full paths).
    """
    # Filter first so only returned files are stat'ed (stable sort => same order).
    names = [
        name for name in os.listdir(dirpath)
        if os.path.isfile(os.path.join(dirpath, name))
        and name.endswith('.h5')
        and (withstring is None or withstring in name)
    ]
    names.sort(key=lambda name: os.path.getmtime(os.path.join(dirpath, name)))
    return names
    
def getfiles_hdf5(dirpath, withstring=None):
    """Return the names of the ``.hdf5`` files in ``dirpath``, oldest first.

    Only regular files directly inside ``dirpath`` are considered; the
    result is sorted by modification time, ascending. ``.h5`` files are
    NOT matched (see getfiles_h5).

    Parameters
    ----------
    dirpath : str
        Directory to list.
    withstring : str, optional
        If given, keep only file names containing this substring.

    Returns
    -------
    list of str
        Bare file names (not full paths).
    """
    # Filter first so only returned files are stat'ed (stable sort => same order).
    names = [
        name for name in os.listdir(dirpath)
        if os.path.isfile(os.path.join(dirpath, name))
        and name.endswith('.hdf5')
        and (withstring is None or withstring in name)
    ]
    names.sort(key=lambda name: os.path.getmtime(os.path.join(dirpath, name)))
    return names


# obsolete
def save_h5(fullpath, datasets, group=None, overwrite=False):
    """Write ``datasets`` into the HDF5 file ``fullpath`` (obsolete helper).

    Parameters
    ----------
    fullpath : str
        Path of the HDF5 file.
    datasets : mapping
        ``{name: data}``; each name is converted with ``str`` and each value
        is stored as one dataset via ``create_dataset(name, data=value)``.
    group : optional
        When not ``None`` (and not the empty string), the datasets are written
        into a newly created group named ``str(group)`` (e.g. the value of the
        swept parameter). Falsy values such as ``0`` or ``0.0`` are valid
        group names. Otherwise the datasets are written at the file root.
    overwrite : bool
        Open the file with mode 'w' (truncate) instead of 'a' (append).

    Side effect: sets ``h5py.get_config().track_order = True``, a process-wide
    h5py setting, so it is not limited to this call.
    NOTE(review): creating a group/dataset whose name already exists in the
    file presumably raises in h5py; confirm callers never reuse names.
    """
    h5py.get_config().track_order = True

    mode = 'w' if overwrite else 'a'
    # A plain `if group:` would silently drop falsy group names such as 0,
    # writing the data to the root instead; only None/"" mean "no group".
    group_name = None if group is None else str(group)
    with h5py.File(fullpath, mode, track_order=True) as fileH5:
        if group_name:
            target = fileH5.create_group(group_name, track_order=True)
        else:
            target = fileH5
        for key in datasets:
            target.create_dataset(str(key), data=datasets[key])


def load_h5(fullpath):
    """Load all datasets of an HDF5 file that is one or two levels deep.

    The layout is decided by the first top-level entry:

    * if it is a dataset, returns ``(main_keys, data)`` where ``data[i]`` is
      the contents of top-level dataset ``main_keys[i]``;
    * otherwise returns ``(main_keys, datasets_keys, data)`` where
      ``data[j][k]`` is dataset ``k`` of group ``main_keys[j]``.
      NOTE: ``datasets_keys`` holds the dataset names of the LAST group only.

    Every dataset is read into memory before the file is closed, so the
    returned numpy arrays stay valid after this function returns.

    Raises
    ------
    IndexError
        If the file has no top-level entries.
    """
    with h5py.File(fullpath, 'r') as file:
        main_keys = list(file.keys())

        if isinstance(file[main_keys[0]], h5py.Dataset):
            # np.asarray reads the dataset now, while the file is still open.
            data_vector = [np.asarray(file[key]) for key in main_keys]
            return np.array(main_keys), np.array(data_vector)

        data_vector = []
        datasets_keys = []
        for key in main_keys:
            datasets_keys = list(file[key].keys())
            data_vector.append([np.asarray(file[key][d_key]) for d_key in datasets_keys])
        return np.array(main_keys), np.array(datasets_keys), np.array(data_vector)

def load_h5_v2(fullpath):
    """Load an HDF5 file with one or two levels of nesting into numpy arrays.

    If the first top-level entry is a dataset, every top-level entry is
    treated as a dataset and ``(names, data)`` is returned. Otherwise every
    top-level entry is treated as a group and ``(group_names, member_names,
    data)`` is returned, with ``data[j][k]`` = member ``k`` of group ``j``;
    ``member_names`` are those of the last group visited.

    The arrays are built while the file is still open (the ``return``
    expressions are evaluated inside the ``with`` block).
    Raises IndexError on a file with no top-level entries.
    """
    with h5py.File(fullpath, 'r') as h5:
        top_level = list(h5.keys())

        if isinstance(h5[top_level[0]], h5py.Dataset):
            flat = [h5[name] for name in top_level]
            return np.array(top_level), np.array(flat)

        grouped = []
        for name in top_level:
            member_names = list(h5[name].keys())
            grouped.append([h5[name][member] for member in member_names])
        return np.array(top_level), np.array(member_names), np.array(grouped)
    
## load into a dictionary together with all the keys
def load_h5_to_dic(fullpath):
    """Load an HDF5 file into (nested) dictionaries of in-memory values.

    If the first entry under the root is a dataset, returns
    ``({name: value}, names)`` for all root entries. Otherwise each root
    entry is treated as a group and the function returns
    ``({group: {member: value}}, {group: [member names]})``.
    Values are read with ``dataset[()]``, so they are independent of the
    (closed) file. Raises IndexError if the root is empty.
    """
    with h5py.File(fullpath, 'r') as h5file:
        top_keys = list(h5file["/"].keys())

        if isinstance(h5file[top_keys[0]], h5py.Dataset):
            flat = {}
            for name in top_keys:
                flat[name] = h5file[name][()]
            return flat, top_keys

        nested = {}
        keys_per_group = {}
        for group_name in top_keys:
            member_names = list(h5file[group_name].keys())
            keys_per_group[group_name] = list(h5file[group_name].keys())
            nested[group_name] = {
                member: h5file[group_name][member][()] for member in member_names
            }
        return nested, keys_per_group
    
#try:
#    mwg
#except:
#    mwg = {}
#load_MWG(instruments, mwg)

#update_MWG(instruments, mwg)
#status_MWG(instruments, mwg)

# Module-level variable initialised to None. No function in this section reads
# or writes it; NOTE(review): presumably used by notebook code elsewhere —
# confirm whether it is still needed.
time_stamp_old = None


## Loader for HDF5 files with a different layout (two levels of groups above the datasets: /group/subgroup/dataset), such as the one from D:\Erbium_SMPD\20210729-run6\02 SpinDetection60p2mT\Spin_detection_vs_power_vs_B0\20210730190401_\amplitude_0.1
def load_h5_2subgroups(fullpath):
    """Load an HDF5 file laid out as ``/<main group>/<subgroup>/<dataset>``.

    Returns
    -------
    tuple of numpy arrays
        ``(second_keys, datasets_keys, data)`` where ``data[i][j][k]`` is
        dataset ``k`` of subgroup ``j`` of top-level group ``i``.
        NOTE: the two key arrays come from the LAST group/subgroup visited,
        so they only describe the whole file if all groups share one layout.
        An empty file yields empty key arrays.

    The file is closed before returning; every dataset is read into memory
    first, so the result does not depend on an open file handle.
    """
    # Pre-initialised so an empty file returns empty arrays instead of
    # failing with NameError.
    second_keys_list = []
    datasets_keys = []
    data_vector = []

    with h5py.File(fullpath, 'r') as file:
        for main_key in file:
            main_group = file[main_key]
            second_keys_list = list(main_group)
            group_data = []
            for second_key in second_keys_list:
                subgroup = main_group[second_key]
                datasets_keys = list(subgroup)
                # np.asarray reads the data now, while the file is open.
                group_data.append([np.asarray(subgroup[d_key]) for d_key in datasets_keys])
            data_vector.append(group_data)

        return (np.array(second_keys_list), np.array(datasets_keys), np.array(data_vector))




@register_cell_magic
def write_and_run(line, cell):
    """IPython cell magic: save the cell source to a file, then execute it.

    Usage: ``%%write_and_run [-a] path`` — the last word of the magic line is
    the target file; with exactly two words and ``-a`` first, the cell is
    appended instead of overwriting the file.
    """
    words = line.split()
    target_path = words[-1]
    append = len(words) == 2 and words[0] == '-a'
    open_mode = 'a' if append else 'w'
    with open(target_path, open_mode) as handle:
        handle.write(cell)
    get_ipython().run_cell(cell)


def update_progress(progress, bar_length=80):
    """Print a one-line text progress bar to stdout.

    Parameters
    ----------
    progress : float
        Fraction complete. Values >= 0.99 are shown as 100% with a "Done!"
        status (plus a newline); negative values are shown as 0%.
    bar_length : int, optional
        Width of the bar in characters (default 80).

    The line starts and ends with a carriage return so successive calls
    overwrite each other in place.
    """
    if progress >= 0.99:
        progress = 1
        status = "     Done!\r\n"
    else:
        # Clamp: a negative fraction would make `block` negative and widen
        # the bar past `bar_length` while printing a negative percentage.
        progress = max(progress, 0)
        status = "Running..."
    block = int(round(bar_length * progress))
    bar = "=" * block + " " * (bar_length - block)
    text = "\rProgress: [{0}] {1}% {2}".format(bar, int(progress * 100), status)
    print(text, end="\r")


def get_timestamp():
    """Return the current local date-time as 'YYYYMMDDhhmmss_'.

    The trailing underscore is part of the format string.
    """
    stamp_format = "%Y%m%d%H%M%S_"
    now = datetime.datetime.now()
    return now.strftime(stamp_format)

def make_exp_directory(path, experiment_name):
    """Create (if needed) and return the directory ``path + experiment_name``.

    ``path`` and ``experiment_name`` are concatenated as-is (no separator is
    inserted), so ``path`` is expected to end with a path separator. The
    returned string ends with the platform separator (``os.sep``), so file
    names can be appended to it directly.

    Prints 'directory created' when the directory did not exist yet.
    """
    # os.sep instead of a hard-coded '\\': on POSIX the backslash is an
    # ordinary character and would create a directory literally named "name\".
    # On Windows os.sep == '\\', so behaviour there is unchanged.
    directory = path + experiment_name + os.sep
    if not os.path.exists(directory):
        os.makedirs(directory)
        print('directory created')
    return directory

def QUA_timer(job, N_iterations):
    """Block until the job's iteration counter reaches ``N_iterations - 1``.

    Polls the 'interation' result stream of ``job`` (presumably a Quantum
    Machines QUA job — TODO confirm) about once per second and prints the
    completion percentage whenever its integer value changes.

    NOTE: loops forever if the counter never reaches ``N_iterations - 1``
    (e.g. the job is halted early).
    """
    # 'interation' (sic) must match the stream name saved by the program.
    index_handle = job.result_handles.get('interation')
    index = 0
    last_index = 0
    while index < N_iterations - 1:
        if index_handle.count_so_far() > 0:
            index = index_handle.fetch_all()
            time.sleep(1)
            if int(index / N_iterations * 100) != int(last_index / N_iterations * 100):
                last_index = index
                print(int(index / N_iterations * 100), '%')
        else:
            # Nothing streamed yet: wait instead of busy-spinning on
            # count_so_far(), which would peg a CPU core and hammer the server.
            time.sleep(1)

def QUA_timer_simple(job, N_iterations):
    """Draw a single progress-bar update from the job's iteration counter.

    Does nothing if the 'interation' stream (sic — must match the program's
    stream name) has no values yet; otherwise passes
    ``current_index / N_iterations`` to update_progress.
    """
    counter = job.result_handles.get('interation')
    if counter.count_so_far() <= 0:
        return
    current = counter.fetch_all()
    update_progress(current / N_iterations)

def QUA_live_data(job, N_iterations, x_data, correct_edelay=None, JupyterLab=True):
    """Live-plot the averaged I/Q result streams of a running job.

    About once per second, fetches the 'Iav_vs_f', 'Qav_vs_f' and
    'interation' streams of ``job`` and redraws two panels against
    ``x_data``: I and Q, or magnitude and phase when ``correct_edelay`` is
    given. Returns once the 'interation' counter reaches ``N_iterations - 1``.

    Parameters
    ----------
    job
        Running job exposing ``result_handles`` (presumably a Quantum
        Machines QUA job — TODO confirm).
    N_iterations : int
        Total number of iterations reported on the 'interation' stream.
    x_data : array
        Horizontal axis; also multiplies ``correct_edelay`` in the delay
        correction (presumably frequencies in Hz with the delay in s — TODO
        confirm).
    correct_edelay : float, optional
        Electrical delay removed via ``exp(-2j*pi*x_data*correct_edelay)``.
    JupyterLab : bool
        True: redraw through IPython display. False: draw into an
        interactive matplotlib window.
    """
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    if not JupyterLab:
        # Interactive window: show it once, then redraw in place each loop.
        plt.ion()
        fig.show()
        fig.canvas.draw()

    I_handle = job.result_handles.get('Iav_vs_f')
    Q_handle = job.result_handles.get('Qav_vs_f')
    # 'interation' (sic) must match the stream name saved by the program.
    index_handle = job.result_handles.get('interation')
    index = 0
    last_index = 0

    while index < N_iterations - 1:
        if index_handle.count_so_far() <= 0:
            # Nothing streamed yet: wait instead of busy-spinning on the server.
            time.sleep(1)
            continue

        I = I_handle.fetch_all()
        Q = Q_handle.fetch_all()
        index = index_handle.fetch_all()

        if JupyterLab:
            display.clear_output(wait=True)

        if int(index / N_iterations * 100) != int(last_index / N_iterations * 100):
            last_index = index
            update_progress(index / N_iterations)

        ax[0].clear()
        ax[1].clear()
        if correct_edelay is not None:
            # Remove the linear phase ramp produced by an electrical delay.
            IQ_corrected = np.array((I - 1j * Q) * np.exp(-1j * 2 * np.pi * x_data * correct_edelay))
            I_new = np.real(IQ_corrected)
            Q_new = np.imag(IQ_corrected)
            mag = np.sqrt(I_new ** 2 + Q_new ** 2)
            # NOTE(review): this is arctan(I/Q), limited to (-pi/2, pi/2) and
            # dividing by Q, not the usual arctan2(Q, I); kept as-is — confirm intent.
            phase = np.arctan(I_new / Q_new)
            ax[0].plot(x_data, mag)
            ax[1].plot(x_data, phase)
            ax[0].set_ylabel('Mag')
            ax[1].set_ylabel('Phase')
        else:
            ax[0].plot(x_data, I)
            ax[1].plot(x_data, Q)
            ax[0].set_ylabel('I')
            ax[1].set_ylabel('Q')

        if JupyterLab:
            display.display(plt.gcf())
        else:
            fig.canvas.draw()

        time.sleep(1)


def QUA_live_data_2D_EA(job, N_iterations, params, streams):
    """Live-plot two 2D result streams of a running job as colour maps.

    Each stream named in ``streams`` is fetched and reshaped to
    ``(len(params[0]), len(params[1]))``; the first two are drawn with
    ``pcolorfast`` against ``params[1]`` (x) and ``params[0]`` (y) and shown
    through IPython display. Returns once the 'interation' counter reaches
    ``N_iterations - 1``.

    Parameters
    ----------
    job
        Running job exposing ``result_handles`` (presumably a Quantum
        Machines QUA job — TODO confirm).
    N_iterations : int
        Total number of iterations reported on the 'interation' stream.
    params : sequence of two sequences
        Sweep values of the two axes: ``params[0]`` rows, ``params[1]`` columns.
    streams : sequence of str
        Result-stream names; at least two are required (only the first two
        are plotted, all are fetched).
    """
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    stream_handles = [job.result_handles.get(stream) for stream in streams]
    map_shape = (len(params[0]), len(params[1]))

    # 'interation' (sic) must match the stream name saved by the program.
    index_handle = job.result_handles.get('interation')
    index = 0
    last_index = 0

    while index < N_iterations - 1:
        if index_handle.count_so_far() <= 0:
            # Nothing streamed yet: wait instead of busy-spinning on the server.
            time.sleep(1)
            continue

        p_stream = [np.reshape(np.array(handle.fetch_all()), map_shape)
                    for handle in stream_handles]
        index = index_handle.fetch_all()

        if int(index / N_iterations * 100) != int(last_index / N_iterations * 100):
            last_index = index
            update_progress(index / N_iterations)

        # pcolorfast adds a new artist on every call; clear first so the axes
        # don't accumulate one image per refresh (memory and redraw time grow).
        ax[0].clear()
        ax[1].clear()
        ax[0].pcolorfast(params[1], params[0], p_stream[0])
        ax[1].pcolorfast(params[1], params[0], p_stream[1])
        display.clear_output(wait=True)
        display.display(plt.gcf())

def QUA_live_data_2D(job, N_iterations, n_steps1, n_steps2):
    """Live-plot the 'p_ON' and 'p_OFF' result streams as 2D colour maps.

    About once per second, both streams are fetched, reshaped to
    ``(n_steps1, n_steps2)`` and drawn with ``pcolor`` side by side through
    IPython display. Returns once the 'interation' counter reaches
    ``N_iterations - 1``.

    Parameters
    ----------
    job
        Running job exposing ``result_handles`` (presumably a Quantum
        Machines QUA job — TODO confirm).
    N_iterations : int
        Total number of iterations reported on the 'interation' stream.
    n_steps1, n_steps2 : int
        Shape the flat streams are reshaped to (rows, columns).
    """
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    p_ON_handle = job.result_handles.get('p_ON')
    p_OFF_handle = job.result_handles.get('p_OFF')
    # 'interation' (sic) must match the stream name saved by the program.
    index_handle = job.result_handles.get('interation')
    index = 0
    last_index = 0

    while index < N_iterations - 1:
        # Wait on every pass, including before any data has arrived, so the
        # loop never busy-spins on count_so_far().
        time.sleep(1)
        if index_handle.count_so_far() <= 0:
            continue

        p_ON = np.reshape(np.array(p_ON_handle.fetch_all()), (n_steps1, n_steps2))
        p_OFF = np.reshape(np.array(p_OFF_handle.fetch_all()), (n_steps1, n_steps2))
        index = index_handle.fetch_all()

        display.clear_output(wait=True)

        if int(index / N_iterations * 100) != int(last_index / N_iterations * 100):
            last_index = index
            update_progress(index / N_iterations)

        # pcolor adds a new collection each call; clear so artists don't pile up.
        ax[0].clear()
        ax[1].clear()
        ax[0].pcolor(p_ON)
        ax[1].pcolor(p_OFF)

        display.display(plt.gcf())


def _draw_iq_histograms(ax, I_ON, Q_ON, I_OFF, Q_OFF, x_range, y_range):
    """Redraw the 2x3 grid of IQ histograms (ON vs OFF) used by QUA_live_histo."""
    for a in ax.flat:
        a.clear()

    # Column 0: 2D IQ histograms (ON on top, OFF below) on a fixed common range.
    ax[0, 0].hist2d(I_ON, Q_ON, bins=31, range=[x_range, y_range])
    ax[1, 0].hist2d(I_OFF, Q_OFF, bins=31, range=[x_range, y_range])
    ax[0, 0].set_aspect(1)
    ax[1, 0].set_aspect(1)

    hist_I_ON, bins_I_ON = np.histogram(I_ON, bins=51, range=x_range)
    hist_I_OFF, bins_I_OFF = np.histogram(I_OFF, bins=51, range=x_range)
    hist_Q_ON, bins_Q_ON = np.histogram(Q_ON, bins=51, range=y_range)
    hist_Q_OFF, bins_Q_OFF = np.histogram(Q_OFF, bins=51, range=y_range)

    # Column 1: linear-scale 1D histograms (top: I, bottom: Q), ON and OFF overlaid.
    ax[0, 1].plot(bins_I_ON[:-1], hist_I_ON)
    ax[0, 1].plot(bins_I_OFF[:-1], hist_I_OFF)
    ax[1, 1].plot(bins_Q_ON[:-1], hist_Q_ON)
    ax[1, 1].plot(bins_Q_OFF[:-1], hist_Q_OFF)

    # Column 2: the same 1D histograms on a logarithmic y-axis.
    ax[0, 2].semilogy(bins_I_ON[:-1], hist_I_ON)
    ax[0, 2].semilogy(bins_I_OFF[:-1], hist_I_OFF)
    ax[1, 2].semilogy(bins_Q_ON[:-1], hist_Q_ON)
    ax[1, 2].semilogy(bins_Q_OFF[:-1], hist_Q_OFF)


def QUA_live_histo(job, N_iterations, JupyterLab=True):
    """Live-plot histograms of the single-shot I/Q streams of a running job.

    About once per second, takes the latest batch of the 'I_ON', 'Q_ON',
    'I_OFF' and 'Q_OFF' streams and redraws a 2x3 grid: 2D IQ histograms,
    then linear and log-scale 1D histograms of I (top) and Q (bottom).
    Histogram ranges are fixed from the first batch so frames are comparable.
    Returns once the 'interation' counter reaches ``N_iterations - 1``.

    Parameters
    ----------
    job
        Running job exposing ``result_handles`` (presumably a Quantum
        Machines QUA job — TODO confirm).
    N_iterations : int
        Total number of iterations reported on the 'interation' stream.
    JupyterLab : bool
        True: redraw through IPython display. False: draw into an
        interactive matplotlib window.
    """
    I_ON_handle = job.result_handles.get('I_ON')
    Q_ON_handle = job.result_handles.get('Q_ON')
    I_OFF_handle = job.result_handles.get('I_OFF')
    Q_OFF_handle = job.result_handles.get('Q_OFF')
    # 'interation' (sic) must match the stream name saved by the program.
    index_handle = job.result_handles.get('interation')
    index = 0
    last_index = 0

    fig, ax = plt.subplots(2, 3, figsize=(10, 5))

    if not JupyterLab:
        plt.ion()
        fig.show()
        fig.canvas.draw()

    x_range = y_range = None  # fixed from the first batch of data
    while index < N_iterations - 1:

        count = I_ON_handle.count_so_far()
        if count > 0:
            index = index_handle.fetch_all()

            # Latest record only; the [0][0] unpacking presumably matches the
            # stream's record layout — NOTE(review): confirm against the program.
            I_ON = np.array(I_ON_handle.fetch(count - 1)[0][0])
            Q_ON = np.array(Q_ON_handle.fetch(count - 1)[0][0])
            I_OFF = np.array(I_OFF_handle.fetch(count - 1)[0][0])
            Q_OFF = np.array(Q_OFF_handle.fetch(count - 1)[0][0])

            if x_range is None:
                x_range = [min(np.min(I_ON), np.min(I_OFF)), max(np.max(I_ON), np.max(I_OFF))]
                y_range = [min(np.min(Q_ON), np.min(Q_OFF)), max(np.max(Q_ON), np.max(Q_OFF))]

            _draw_iq_histograms(ax, I_ON, Q_ON, I_OFF, Q_OFF, x_range, y_range)

            time.sleep(1)
        else:
            # No data yet: wait instead of clearing/redisplaying in a tight loop.
            time.sleep(1)

        if JupyterLab:
            display.clear_output(wait=True)

        if int(index / N_iterations * 100) != int(last_index / N_iterations * 100):
            last_index = index
            update_progress(index / N_iterations)

        if JupyterLab:
            display.display(plt.gcf())
        else:
            fig.canvas.draw()

